{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example of loading a custom tree model into SHAP\n",
    "\n",
    "This notebook shows how to pass a custom tree ensemble model into SHAP for explanation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import graphviz\n",
    "import numpy as np\n",
    "import scipy\n",
    "import sklearn\n",
    "\n",
    "import shap"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple regression tree model\n",
    "\n",
    "Here we define a simple regression tree and then load it into SHAP as a custom model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>DecisionTreeRegressor(max_depth=2)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">DecisionTreeRegressor</label><div class=\"sk-toggleable__content\"><pre>DecisionTreeRegressor(max_depth=2)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "DecisionTreeRegressor(max_depth=2)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the adult dataset shipped with SHAP and fit a depth-2 regression tree on it.\n",
    "X, y = shap.datasets.adult()\n",
    "\n",
    "orig_model = sklearn.tree.DecisionTreeRegressor(max_depth=2).fit(X, y)\n",
    "orig_model  # last expression of the cell: display the fitted estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.43.0 (0)\n",
       " -->\n",
       "<!-- Title: Tree Pages: 1 -->\n",
       "<svg width=\"740pt\" height=\"261pt\"\n",
       " viewBox=\"0.00 0.00 740.00 261.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 257)\">\n",
       "<title>Tree</title>\n",
       "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-257 736,-257 736,4 -4,4\"/>\n",
       "<!-- 0 -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>0</title>\n",
       "<path fill=\"#fae5d6\" stroke=\"black\" d=\"M450,-253C450,-253 300,-253 300,-253 294,-253 288,-247 288,-241 288,-241 288,-201 288,-201 288,-195 294,-189 300,-189 300,-189 450,-189 450,-189 456,-189 462,-195 462,-201 462,-201 462,-241 462,-241 462,-247 456,-253 450,-253\"/>\n",
       "<text text-anchor=\"start\" x=\"344.5\" y=\"-238.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">x</text>\n",
       "<text text-anchor=\"start\" x=\"353.5\" y=\"-238.8\" font-family=\"Helvetica,sans-Serif\" baseline-shift=\"sub\" font-size=\"14.00\">5</text>\n",
       "<text text-anchor=\"start\" x=\"361.5\" y=\"-238.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\"> ≤ 3.5</text>\n",
       "<text text-anchor=\"start\" x=\"296\" y=\"-224.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.183</text>\n",
       "<text text-anchor=\"start\" x=\"313\" y=\"-210.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 32561</text>\n",
       "<text text-anchor=\"start\" x=\"325\" y=\"-196.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 0.241</text>\n",
       "</g>\n",
       "<!-- 1 -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>1</title>\n",
       "<path fill=\"#fffdfb\" stroke=\"black\" d=\"M354,-153C354,-153 204,-153 204,-153 198,-153 192,-147 192,-141 192,-141 192,-101 192,-101 192,-95 198,-89 204,-89 204,-89 354,-89 354,-89 360,-89 366,-95 366,-101 366,-101 366,-141 366,-141 366,-147 360,-153 354,-153\"/>\n",
       "<text text-anchor=\"start\" x=\"235\" y=\"-138.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">x</text>\n",
       "<text text-anchor=\"start\" x=\"244\" y=\"-138.8\" font-family=\"Helvetica,sans-Serif\" baseline-shift=\"sub\" font-size=\"14.00\">8</text>\n",
       "<text text-anchor=\"start\" x=\"252\" y=\"-138.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\"> ≤ 7073.5</text>\n",
       "<text text-anchor=\"start\" x=\"200\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.062</text>\n",
       "<text text-anchor=\"start\" x=\"217\" y=\"-110.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 17800</text>\n",
       "<text text-anchor=\"start\" x=\"229\" y=\"-96.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 0.066</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;1 -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>0&#45;&gt;1</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M344.62,-188.99C335.78,-179.97 326.01,-169.99 316.77,-160.56\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"319,-157.84 309.51,-153.14 314,-162.74 319,-157.84\"/>\n",
       "<text text-anchor=\"middle\" x=\"309.25\" y=\"-174.44\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\n",
       "</g>\n",
       "<!-- 4 -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>4</title>\n",
       "<path fill=\"#f4c8a8\" stroke=\"black\" d=\"M546,-153C546,-153 396,-153 396,-153 390,-153 384,-147 384,-141 384,-141 384,-101 384,-101 384,-95 390,-89 396,-89 396,-89 546,-89 546,-89 552,-89 558,-95 558,-101 558,-101 558,-141 558,-141 558,-147 552,-153 546,-153\"/>\n",
       "<text text-anchor=\"start\" x=\"436\" y=\"-138.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">x</text>\n",
       "<text text-anchor=\"start\" x=\"445\" y=\"-138.8\" font-family=\"Helvetica,sans-Serif\" baseline-shift=\"sub\" font-size=\"14.00\">2</text>\n",
       "<text text-anchor=\"start\" x=\"453\" y=\"-138.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\"> ≤ 12.5</text>\n",
       "<text text-anchor=\"start\" x=\"392\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.248</text>\n",
       "<text text-anchor=\"start\" x=\"409\" y=\"-110.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 14761</text>\n",
       "<text text-anchor=\"start\" x=\"421\" y=\"-96.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 0.451</text>\n",
       "</g>\n",
       "<!-- 0&#45;&gt;4 -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>0&#45;&gt;4</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M405.38,-188.99C414.22,-179.97 423.99,-169.99 433.23,-160.56\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"436,-162.74 440.49,-153.14 431,-157.84 436,-162.74\"/>\n",
       "<text text-anchor=\"middle\" x=\"440.75\" y=\"-174.44\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\n",
       "</g>\n",
       "<!-- 2 -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>2</title>\n",
       "<path fill=\"#ffffff\" stroke=\"black\" d=\"M162,-53C162,-53 12,-53 12,-53 6,-53 0,-47 0,-41 0,-41 0,-12 0,-12 0,-6 6,0 12,0 12,0 162,0 162,0 168,0 174,-6 174,-12 174,-12 174,-41 174,-41 174,-47 168,-53 162,-53\"/>\n",
       "<text text-anchor=\"start\" x=\"8\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.047</text>\n",
       "<text text-anchor=\"start\" x=\"25\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 17482</text>\n",
       "<text text-anchor=\"start\" x=\"41.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 0.05</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;2 -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>1&#45;&gt;2</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M214.53,-88.94C193.47,-78.79 170.12,-67.54 149.22,-57.48\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"150.65,-54.28 140.12,-53.09 147.61,-60.58 150.65,-54.28\"/>\n",
       "</g>\n",
       "<!-- 3 -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>3</title>\n",
       "<path fill=\"#e58139\" stroke=\"black\" d=\"M354,-53C354,-53 204,-53 204,-53 198,-53 192,-47 192,-41 192,-41 192,-12 192,-12 192,-6 198,0 204,0 204,0 354,0 354,0 360,0 366,-6 366,-12 366,-12 366,-41 366,-41 366,-47 360,-53 354,-53\"/>\n",
       "<text text-anchor=\"start\" x=\"200\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.036</text>\n",
       "<text text-anchor=\"start\" x=\"226\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 318</text>\n",
       "<text text-anchor=\"start\" x=\"229\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 0.962</text>\n",
       "</g>\n",
       "<!-- 1&#45;&gt;3 -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>1&#45;&gt;3</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M279,-88.94C279,-80.66 279,-71.64 279,-63.13\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"282.5,-63.09 279,-53.09 275.5,-63.09 282.5,-63.09\"/>\n",
       "</g>\n",
       "<!-- 5 -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>5</title>\n",
       "<path fill=\"#f7d8c1\" stroke=\"black\" d=\"M546,-53C546,-53 396,-53 396,-53 390,-53 384,-47 384,-41 384,-41 384,-12 384,-12 384,-6 390,0 396,0 396,0 546,0 546,0 552,0 558,-6 558,-12 558,-12 558,-41 558,-41 558,-47 552,-53 546,-53\"/>\n",
       "<text text-anchor=\"start\" x=\"392\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.223</text>\n",
       "<text text-anchor=\"start\" x=\"409\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 10329</text>\n",
       "<text text-anchor=\"start\" x=\"421\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 0.335</text>\n",
       "</g>\n",
       "<!-- 4&#45;&gt;5 -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>4&#45;&gt;5</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M471,-88.94C471,-80.66 471,-71.64 471,-63.13\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"474.5,-63.09 471,-53.09 467.5,-63.09 474.5,-63.09\"/>\n",
       "</g>\n",
       "<!-- 6 -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>6</title>\n",
       "<path fill=\"#eca26d\" stroke=\"black\" d=\"M720,-53C720,-53 588,-53 588,-53 582,-53 576,-47 576,-41 576,-41 576,-12 576,-12 576,-6 582,0 588,0 588,0 720,0 720,0 726,0 732,-6 732,-12 732,-12 732,-41 732,-41 732,-47 726,-53 720,-53\"/>\n",
       "<text text-anchor=\"start\" x=\"584\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">squared_error = 0.2</text>\n",
       "<text text-anchor=\"start\" x=\"596.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 4432</text>\n",
       "<text text-anchor=\"start\" x=\"604\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = 0.724</text>\n",
       "</g>\n",
       "<!-- 4&#45;&gt;6 -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>4&#45;&gt;6</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M532.45,-88.94C552.44,-78.84 574.58,-67.65 594.43,-57.61\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"596.03,-60.73 603.37,-53.09 592.87,-54.48 596.03,-60.73\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.sources.Source at 0x7fab14963e20>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Draw the fitted tree so its structure can be compared with the node arrays printed further down.\n",
    "dot_data = sklearn.tree.export_graphviz(\n",
    "    orig_model,\n",
    "    out_file=None,  # no output file; the DOT source is handed to graphviz below\n",
    "    filled=True,\n",
    "    rounded=True,\n",
    "    special_characters=True,\n",
    ")\n",
    "graph = graphviz.Source(dot_data)\n",
    "graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For more information on what the tree attributes extracted below mean exactly, see the [scikit-learn documentation](https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html#tree-structure)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     children_left [ 1  2 -1 -1  5 -1 -1]\n",
      "    children_right [ 4  3 -1 -1  6 -1 -1]\n",
      "  children_default [ 4  3 -1 -1  6 -1 -1]\n",
      "          features [ 5  8 -2 -2  2 -2 -2]\n",
      "        thresholds [ 3.5000e+00  7.0735e+03 -2.0000e+00 -2.0000e+00  1.2500e+01 -2.0000e+00\n",
      " -2.0000e+00]\n",
      "            values [[0.241]\n",
      " [0.066]\n",
      " [0.05 ]\n",
      " [0.962]\n",
      " [0.451]\n",
      " [0.335]\n",
      " [0.724]]\n",
      "node_sample_weight [32561. 17800. 17482.   318. 14761. 10329.  4432.]\n"
     ]
    }
   ],
   "source": [
    "# Pull the flat per-node arrays that define the fitted tree out of sklearn's `tree_` object.\n",
    "sk_tree = orig_model.tree_\n",
    "\n",
    "children_left = sk_tree.children_left\n",
    "children_right = sk_tree.children_right\n",
    "# sklearn does not use missing values, so the default branch is just a copy of the right child\n",
    "children_default = children_right.copy()\n",
    "features = sk_tree.feature\n",
    "thresholds = sk_tree.threshold\n",
    "values = sk_tree.value.reshape(len(sk_tree.value), 1)  # one row per node\n",
    "node_sample_weight = sk_tree.weighted_n_node_samples\n",
    "\n",
    "# Reading the printout: a negative child index marks a leaf node,\n",
    "# and a threshold of -2 likewise means the node is a leaf.\n",
    "for label, array in (\n",
    "    (\"     children_left\", children_left),\n",
    "    (\"    children_right\", children_right),\n",
    "    (\"  children_default\", children_default),\n",
    "    (\"          features\", features),\n",
    "    (\"        thresholds\", thresholds.round(3)),\n",
    "    (\"            values\", values.round(3)),\n",
    "    (\"node_sample_weight\", node_sample_weight),\n",
    "):\n",
    "    print(label, array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Describe the tree as a plain dict of node arrays, then wrap it in the\n",
    "# model dict whose \"trees\" entry lists every tree of the custom model.\n",
    "tree_dict = dict(\n",
    "    children_left=children_left,\n",
    "    children_right=children_right,\n",
    "    children_default=children_default,\n",
    "    features=features,\n",
    "    thresholds=thresholds,\n",
    "    values=values,\n",
    "    node_sample_weight=node_sample_weight,\n",
    ")\n",
    "model = dict(trees=[tree_dict])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the explainer directly from the dict-based model; no background data is passed here.\n",
    "explainer = shap.TreeExplainer(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure that the ingested SHAP model (a TreeEnsemble object) makes the\n",
    "# same predictions as the original model. assert_allclose is used instead of a\n",
    "# bare assert so that a failure reports how many entries differ and by how much.\n",
    "np.testing.assert_allclose(\n",
    "    explainer.model.predict(X),\n",
    "    orig_model.predict(X),\n",
    "    rtol=0,  # compare with an absolute tolerance only, like the original 1e-4 bound\n",
    "    atol=1e-4,\n",
    "    err_msg=\"SHAP TreeEnsemble predictions differ from the original sklearn tree\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make sure the SHAP values sum up to the model output (this is the local accuracy property)\n",
    "shap_total = explainer.expected_value + explainer.shap_values(X).sum(1)\n",
    "\n",
    "# assert_allclose reports the size and number of mismatches if the check ever fails\n",
    "np.testing.assert_allclose(\n",
    "    shap_total,\n",
    "    orig_model.predict(X),\n",
    "    rtol=0,  # absolute tolerance only\n",
    "    atol=1e-4,\n",
    "    err_msg=\"expected_value + sum of SHAP values does not match the model output\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple GBM classification model (with 2 trees)\n",
    "\n",
    "Here we define a simple gradient-boosting classifier and then load it into SHAP as a custom model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-2 {color: black;}#sk-container-id-2 pre{padding: 0;}#sk-container-id-2 div.sk-toggleable {background-color: white;}#sk-container-id-2 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-2 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-2 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-2 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-2 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-2 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-2 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-2 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-2 div.sk-label:hover 
label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-2 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-2 div.sk-item {position: relative;z-index: 1;}#sk-container-id-2 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-2 div.sk-item::before, #sk-container-id-2 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-2 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-2 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-2 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-2 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-2 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-2 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-2 div.sk-label-container {text-align: center;}#sk-container-id-2 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-2 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GradientBoostingClassifier(n_estimators=2)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GradientBoostingClassifier</label><div class=\"sk-toggleable__content\"><pre>GradientBoostingClassifier(n_estimators=2)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "GradientBoostingClassifier(n_estimators=2)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the adult dataset again and fit a gradient-boosting classifier limited to two estimators.\n",
    "X2, y2 = shap.datasets.adult()\n",
    "orig_model2 = sklearn.ensemble.GradientBoostingClassifier(n_estimators=2).fit(X2, y2)\n",
    "orig_model2  # last expression of the cell: display the fitted estimator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pull the info of the first tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     children_left1 [ 1  2  3 -1 -1  6 -1 -1  9 10 -1 -1 13 -1 -1]\n",
      "    children_right1 [ 8  5  4 -1 -1  7 -1 -1 12 11 -1 -1 14 -1 -1]\n",
      "  children_default1 [ 8  5  4 -1 -1  7 -1 -1 12 11 -1 -1 14 -1 -1]\n",
      "          features1 [ 5  8  2 -2 -2  0 -2 -2  2  8 -2 -2  8 -2 -2]\n",
      "        thresholds1 [ 3.5000e+00  7.0735e+03  1.2500e+01 -2.0000e+00 -2.0000e+00  2.0500e+01\n",
      " -2.0000e+00 -2.0000e+00  1.2500e+01  5.0955e+03 -2.0000e+00 -2.0000e+00\n",
      "  5.0955e+03 -2.0000e+00 -2.0000e+00]\n",
      "            values1 [[-0.   ]\n",
      " [-0.175]\n",
      " [-0.191]\n",
      " [-1.177]\n",
      " [-0.503]\n",
      " [ 0.721]\n",
      " [-0.223]\n",
      " [ 4.013]\n",
      " [ 0.211]\n",
      " [ 0.094]\n",
      " [ 0.325]\n",
      " [ 4.048]\n",
      " [ 0.483]\n",
      " [ 2.372]\n",
      " [ 4.128]]\n",
      "node_sample_weight1 [3.2561e+04 1.7800e+04 1.7482e+04 1.4036e+04 3.4460e+03 3.1800e+02\n",
      " 5.0000e+00 3.1300e+02 1.4761e+04 1.0329e+04 9.8070e+03 5.2200e+02\n",
      " 4.4320e+03 3.7540e+03 6.7800e+02]\n"
     ]
    }
   ],
   "source": [
    "# Low-level tree structure behind the first estimator of the ensemble.\n",
    "tree_tmp = orig_model2.estimators_[0][0].tree_\n",
    "\n",
    "# Same per-node arrays as for the single regression tree above.\n",
    "children_left1 = tree_tmp.children_left\n",
    "children_right1 = tree_tmp.children_right\n",
    "# sklearn does not use missing values, so the default branch is just a copy of the right child\n",
    "children_default1 = children_right1.copy()\n",
    "features1 = tree_tmp.feature\n",
    "thresholds1 = tree_tmp.threshold\n",
    "values1 = tree_tmp.value.reshape(len(tree_tmp.value), 1)  # one row per node\n",
    "node_sample_weight1 = tree_tmp.weighted_n_node_samples\n",
    "\n",
    "# Reading the printout: a negative child index means the node is a leaf.\n",
    "for label, array in (\n",
    "    (\"     children_left1\", children_left1),\n",
    "    (\"    children_right1\", children_right1),\n",
    "    (\"  children_default1\", children_default1),\n",
    "    (\"          features1\", features1),\n",
    "    (\"        thresholds1\", thresholds1.round(3)),\n",
    "    (\"            values1\", values1.round(3)),\n",
    "    (\"node_sample_weight1\", node_sample_weight1),\n",
    "):\n",
    "    print(label, array)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pull the info of the second tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     children_left2 [ 1  2  3 -1 -1  6 -1 -1  9 10 -1 -1 13 -1 -1]\n",
      "    children_right2 [ 8  5  4 -1 -1  7 -1 -1 12 11 -1 -1 14 -1 -1]\n",
      "  children_default2 [ 8  5  4 -1 -1  7 -1 -1 12 11 -1 -1 14 -1 -1]\n",
      "          features2 [ 5  8  2 -2 -2  0 -2 -2  2  8 -2 -2  8 -2 -2]\n",
      "        thresholds2 [ 3.5000e+00  7.0735e+03  1.3500e+01 -2.0000e+00 -2.0000e+00  2.0500e+01\n",
      " -2.0000e+00 -2.0000e+00  1.2500e+01  5.0955e+03 -2.0000e+00 -2.0000e+00\n",
      "  5.0955e+03 -2.0000e+00 -2.0000e+00]\n",
      "            values2 [[-1.000e-03]\n",
      " [-1.580e-01]\n",
      " [-1.720e-01]\n",
      " [-1.062e+00]\n",
      " [ 1.360e-01]\n",
      " [ 6.420e-01]\n",
      " [-2.030e-01]\n",
      " [ 2.993e+00]\n",
      " [ 1.880e-01]\n",
      " [ 8.400e-02]\n",
      " [ 2.870e-01]\n",
      " [ 3.015e+00]\n",
      " [ 4.310e-01]\n",
      " [ 1.895e+00]\n",
      " [ 3.066e+00]]\n",
      "node_sample_weight2 [3.2561e+04 1.7800e+04 1.7482e+04 1.6560e+04 9.2200e+02 3.1800e+02\n",
      " 5.0000e+00 3.1300e+02 1.4761e+04 1.0329e+04 9.8070e+03 5.2200e+02\n",
      " 4.4320e+03 3.7540e+03 6.7800e+02]\n"
     ]
    }
   ],
   "source": [
    "# Low-level tree structure behind the second estimator of the ensemble.\n",
    "tree_tmp = orig_model2.estimators_[1][0].tree_\n",
    "\n",
    "# Same per-node arrays as for the first tree, stored under names ending in 2.\n",
    "children_left2 = tree_tmp.children_left\n",
    "children_right2 = tree_tmp.children_right\n",
    "# sklearn does not use missing values, so the default branch is just a copy of the right child\n",
    "children_default2 = children_right2.copy()\n",
    "features2 = tree_tmp.feature\n",
    "thresholds2 = tree_tmp.threshold\n",
    "values2 = tree_tmp.value.reshape(len(tree_tmp.value), 1)  # one row per node\n",
    "node_sample_weight2 = tree_tmp.weighted_n_node_samples\n",
    "\n",
    "# Reading the printout: a negative child index means the node is a leaf.\n",
    "for label, array in (\n",
    "    (\"     children_left2\", children_left2),\n",
    "    (\"    children_right2\", children_right2),\n",
    "    (\"  children_default2\", children_default2),\n",
    "    (\"          features2\", features2),\n",
    "    (\"        thresholds2\", thresholds2.round(3)),\n",
    "    (\"            values2\", values2.round(3)),\n",
    "    (\"node_sample_weight2\", node_sample_weight2),\n",
    "):\n",
    "    print(label, array)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a list of SHAP Trees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define a custom tree model\n",
    "# The node values of both trees are multiplied by the ensemble's learning rate before being handed to SHAP.\n",
    "learning_rate = orig_model2.learning_rate\n",
    "tree_dicts = [\n",
    "    dict(\n",
    "        children_left=children_left1,\n",
    "        children_right=children_right1,\n",
    "        children_default=children_default1,\n",
    "        features=features1,\n",
    "        thresholds=thresholds1,\n",
    "        values=values1 * learning_rate,\n",
    "        node_sample_weight=node_sample_weight1,\n",
    "    ),\n",
    "    dict(\n",
    "        children_left=children_left2,\n",
    "        children_right=children_right2,\n",
    "        children_default=children_default2,\n",
    "        features=features2,\n",
    "        thresholds=thresholds2,\n",
    "        values=values2 * learning_rate,\n",
    "        node_sample_weight=node_sample_weight2,\n",
    "    ),\n",
    "]\n",
    "\n",
    "# logit of the prior for the class at index 1, read from the classifier's initial estimator\n",
    "base_offset = scipy.special.logit(orig_model2.init_.class_prior_[1])\n",
    "\n",
    "model2 = dict(\n",
    "    trees=tree_dicts,\n",
    "    base_offset=base_offset,\n",
    "    tree_output=\"log_odds\",\n",
    "    objective=\"binary_crossentropy\",\n",
    "    input_dtype=np.float32,  # dtype the model uses for the input feature data\n",
    "    internal_dtype=np.float64,  # dtype the model uses for values and thresholds\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Explain the custom model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build a background dataset for us to use based on people near a 0.95 cutoff\n",
    "proba = orig_model2.predict_proba(X2)[:, 1]\n",
    "vs = np.abs(proba - 0.95)  # distance of every prediction from the cutoff\n",
    "inds = np.argsort(vs)[:200]  # row positions of the 200 predictions closest to the cutoff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build an explainer that explains the probability output of the model,\n",
    "# using the rows selected above as the background dataset\n",
    "background = X2.iloc[inds]\n",
    "explainer2 = shap.TreeExplainer(\n",
    "    model2,\n",
    "    background,\n",
    "    feature_perturbation=\"interventional\",\n",
    "    model_output=\"probability\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure that the ingested SHAP model (a TreeEnsemble object) makes the\n",
    "# same predictions as the original model. assert_allclose is used instead of a\n",
    "# bare assert so that a failure reports how many entries differ and by how much.\n",
    "np.testing.assert_allclose(\n",
    "    explainer2.model.predict(X2, output=\"probability\"),\n",
    "    orig_model2.predict_proba(X2)[:, 1],\n",
    "    rtol=0,  # compare with an absolute tolerance only, like the original 1e-4 bound\n",
    "    atol=1e-4,\n",
    "    err_msg=\"SHAP TreeEnsemble probabilities differ from the original sklearn classifier\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make sure the sum of the SHAP values equals the model output\n",
    "shap_sum = explainer2.expected_value + explainer2.shap_values(X2).sum(1)\n",
    "\n",
    "# assert_allclose reports the size and number of mismatches if the check ever fails\n",
    "np.testing.assert_allclose(\n",
    "    shap_sum,\n",
    "    orig_model2.predict_proba(X2)[:, 1],\n",
    "    rtol=0,  # absolute tolerance only\n",
    "    atol=1e-4,\n",
    "    err_msg=\"expected_value + sum of SHAP values does not match the predicted probability\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
